{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Theano 循环：scan（详解）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using gpu device 1: Tesla C2075 (CNMeM is disabled)\n"
     ]
    }
   ],
   "source": [
    "import theano, time\n",
    "import theano.tensor as T\n",
    "import numpy as np\n",
    "\n",
    "def floatX(X):\n",
    "    return np.asarray(X, dtype=theano.config.floatX)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`theano` 中可以使用 `scan` 进行循环，常用的 `map` 和 `reduce` 操作都可以看成是 `scan` 的特例。 \n",
    "\n",
    "`scan` 通常作用在一个序列上，每次处理一个输入，并输出一个结果。\n",
    "\n",
    "`sum(x)` 函数可以看成是 `z + x(i)` 函数在给定 `z = 0` 的情况下，对 `x` 的一个 `scan`。\n",
    "\n",
    "通常我们可以将一个 `for` 循环表示成一个 `scan` 操作，其好处如下：\n",
    "\n",
    "- 迭代次数成为符号图结构的一部分\n",
    "- 最小化 GPU 数据传递\n",
    "- 序列化梯度计算\n",
    "- 速度比 `for` 稍微快一些\n",
    "- 降低内存使用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scan 的使用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "函数的用法如下：\n",
    "\n",
    "    theano.scan(fn, \n",
    "                sequences=None, \n",
    "                outputs_info=None, \n",
    "                non_sequences=None, \n",
    "                n_steps=None, \n",
    "                truncate_gradient=-1, \n",
    "                go_backwards=False, \n",
    "                mode=None, \n",
    "                name=None, \n",
    "                profile=False, \n",
    "                allow_gc=None, \n",
    "                strict=False)\n",
    "                \n",
    "主要参数的含义：\n",
    "\n",
    "- `fn`\n",
    "    - 一步 `scan` 所进行的操作\n",
    "- `sequences`\n",
    "    - 输入的序列\n",
    "- `outputs_info`\n",
    "    - 前一步输出结果的初始状态\n",
    "- `non_sequences`\n",
    "    - 非序列参数\n",
    "- `n_steps`\n",
    "    - 迭代步数\n",
    "- `go_backwards`\n",
    "    - 是否从后向前遍历\n",
    "    \n",
    "输出为一个元组 `(outputs, updates)`：\n",
    "\n",
    "- `outputs`\n",
    "    - 从初始状态开始，每一步 `fn` 的输出结果\n",
    "- `updates`\n",
    "    - 一个字典，用来记录 `scan` 过程中用到的共享变量更新规则，构造函数的时候，如果需要更新共享变量，将这个变量当作 `updates` 的参数传入。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "##  scan 和 map"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里实现一个简单的 `map` 操作，将向量 $\\mathbf x$ 中的所有元素变成原来的两倍：\n",
    "\n",
    "```python\n",
    "map(lambda t: t * 2, x)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  0.   2.   4.   6.   8.  10.  12.  14.  16.  18.]\n"
     ]
    }
   ],
   "source": [
    "x = T.vector()\n",
    "\n",
    "results, _ = theano.scan(fn = lambda t: t * 2,\n",
    "                         sequences = x)\n",
    "x_double_scan = theano.function([x], results)\n",
    "\n",
    "print x_double_scan(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之前我们说到，`theano` 中的 `map` 是 `scan` 的一个特例，因此 `theano.map` 的用法其实跟 `theano.scan` 十分类似。\n",
    "\n",
    "由于不需要考虑前一步的输出结果，所以 `theano.map` 的参数中没有 `outputs_info` 这一部分。\n",
    "\n",
    "我们用 `theano.map` 实现相同的效果："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  0.   2.   4.   6.   8.  10.  12.  14.  16.  18.]\n"
     ]
    }
   ],
   "source": [
    "result, _ = theano.map(fn = lambda t: t * 2,\n",
    "                       sequences = x)\n",
    "x_double_map = theano.function([x], result)\n",
    "\n",
    "print x_double_map(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scan 和 reduce"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这里一个简单的 `reduce` 操作，求和：\n",
    "\n",
    "```python\n",
    "reduce(lambda a, b: a + b, x)\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45.0\n"
     ]
    }
   ],
   "source": [
    "result, _ = theano.scan(fn = lambda t, v: t + v,\n",
    "                        sequences = x,\n",
    "                        outputs_info = floatX(0.))\n",
    "\n",
    "# 因为每一步的输出值都会被记录到最后的 result 中，所以最后的和是 result 的最后一个元素。\n",
    "x_sum_scan = theano.function([x], result[-1])\n",
    "\n",
    "# 计算 1 + 2 + ... + 10\n",
    "print x_sum_scan(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`theano.reduce` 也是 `scan` 的一个特例，使用 `theano.reduce` 实现相同的效果： "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "45.0\n"
     ]
    }
   ],
   "source": [
    "result, _ = theano.reduce(fn = lambda t, v: t + v,\n",
    "                          sequences = x,\n",
    "                          outputs_info = 0.)\n",
    "\n",
    "x_sum_reduce = theano.function([x], result)\n",
    "\n",
    "# 计算 1 + 2 + ... + 10\n",
    "print x_sum_reduce(range(10))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`reduce` 与 `scan` 不同的地方在于，`result` 包含的内容并不是每次输出的结果，而是最后一次输出的结果。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## scan 的使用"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 输入与输出"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`fn` 是一个函数句柄，对于这个函数句柄，它每一步接受的参数是由 `sequences, outputs_info, non_sequence` 这三个参数所决定的，并且按照以下的顺序排列：\n",
    "\n",
    "- `sequences` 中第一个序列的值\n",
    "- ...\n",
    "- `sequences` 中最后一个序列的值\n",
    "- `outputs_info` 中第一个输出之前的值\n",
    "- ...\n",
    "- `outputs_info` 中最后一个输出之前的值\n",
    "- `non_squences` 中的参数\n",
    "\n",
    "这些序列的顺序与在参数 `sequences, outputs_info` 中指定的顺序相同。\n",
    "\n",
    "默认情况下，在第 `k` 次迭代时，如果 `sequences` 和 `outputs_info` 中给定的值不是字典（`dictionary`）或者一个字典列表（`list of dictionaries`），那么 \n",
    "\n",
    "- `sequences` 中的序列 `seq` 传入 `fn` 的是 `seq[k]` 的值\n",
    "- `outputs_info` 中的序列 `output` 传入 `fn` 的是 `output[k-1]` 的值\n",
    "\n",
    "`fn` 的返回值有两部分 `(outputs_list, update_dictionary)`，第一部分将作为序列，传入 `outputs` 中，与 `outputs_info` 中的**初始输入值的维度一致**（如果没有给定 `outputs_info` ，输出值可以任意。）\n",
    "\n",
    "第二部分则是更新规则的字典，告诉我们如何对 `scan` 中使用到的一些共享的变量进行更新：\n",
    "```python\n",
    "return [y1_t, y2_t], {x:x+1}\n",
    "```\n",
    "\n",
    "这两部分可以任意，即顺序既可以是 `(outputs_list, update_dictionary)`, 也可以是 `(update_dictionary, outputs_list)`，`theano` 会根据类型自动识别。\n",
    "\n",
    "两部分只需要有一个存在即可，另一个可以为空。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 例子分析"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "例如，在我们的第一个例子中\n",
    "\n",
    "```python\n",
    "theano.scan(fn = lambda t: t * 2,\n",
    "            sequences = x)\n",
    "```\n",
    "\n",
    "在第 `k` 次迭代的时候，传入参数 `t` 的值为 `x[k]`。\n",
    "\n",
    "再如，在我们的第二个例子中：\n",
    "\n",
    "```python\n",
    "theano.scan(fn = lambda t, v: t + v,\n",
    "            sequences = x,\n",
    "            outputs_info = floatX(0.))\n",
    "```\n",
    "\n",
    "`fn` 接受了两个参数，初始迭代时，按照规则，`t` 接受的参数为 `x[0]`，`v` 接受的参数为我们传入 `outputs_info` 的第一个初始值即 `0` （认为是 `outputs[-1]`），他们的结果 `t+v` 将作为 `outputs[0]` 的值传入下一次迭代以及最终 `scan` 输出的 `outputs` 值中。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 输入多个序列"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以一次输入多个序列，这些序列会按照顺序传入 fn 的参数中，例如计算多项式\n",
    "$$\n",
    "\\sum_{n=0}^N a_n x^ n\n",
    "$$\n",
    "时，我们可以将多项式的系数和幂数两个序列放到一个 `list` 中作为输入参数："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2301.0\n"
     ]
    }
   ],
   "source": [
    "# 变量 x\n",
    "x = T.scalar(\"x\")\n",
    "\n",
    "# 不为 0 的系数\n",
    "A = T.vectors(\"A\")\n",
    "\n",
    "# 对应的幂数\n",
    "N = T.ivectors(\"N\")\n",
    "\n",
    "# a 对应的是 A， n 对应 N，v 对应 x\n",
    "components, _ = theano.scan(fn = lambda a, n, v: a * (v ** n),\n",
    "                            sequences = [A, N],\n",
    "                            non_sequences = x)\n",
    "\n",
    "result = components.sum()\n",
    "\n",
    "polynomial = theano.function([x, A, N], result)\n",
    "\n",
    "# 计算 1 + 3 * 10 ^ 2 + 2 * 10^3 = 2301\n",
    "print polynomial(floatX(10), \n",
    "                 floatX([1, 3, 2]),\n",
    "                 [0, 2, 3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用序列的多个值"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "默认情况下，我们只能使用输入序列的当前时刻的值，以及前一个输出的输出值。\n",
    "\n",
    "事实上，`theano` 会将参数中的序列变成一个有 `input` 和 `taps` 两个键值的 `dict`：\n",
    "\n",
    "- `input`：输入的序列\n",
    "- `taps`：要传入 `fn` 的值的列表\n",
    "    - 对于 `sequences` 参数中的序列来说，默认值为 [0]，表示时间 `t` 传入 `t+0` 时刻的序列值，可以为正，可以为负。\n",
    "    - 对于 `outputs_info` 参数中的序列来说，默认值为 [-1]，表示时间 `t` 传入 `t-1` 时刻的序列值，只能为负值，如果值为 `None`，表示这个输出结果不会作为参数传入 `fn` 中。\n",
    "    \n",
    "传入 `fn` 的参数也会按照 `taps` 中的顺序来排列，我们考虑下面这个例子：\n",
    "```python\n",
    "scan(fn, sequences = [ dict(input= Sequence1, taps = [-3,2,-1])\n",
    "                     , Sequence2\n",
    "                     , dict(input =  Sequence3, taps = 3) ]\n",
    "       , outputs_info = [ dict(initial =  Output1, taps = [-3,-5])\n",
    "                        , dict(initial = Output2, taps = None)\n",
    "                        , Output3 ]\n",
    "       , non_sequences = [ Argument1, Argument2])\n",
    "```\n",
    "首先是 `Sequence1` 的 `[-3, 2, -1]` 被传入，然后 `Sequence2` 不是 `dict`， 所以传入默认值 `[0]`，`Sequence3` 传入的参数是 `3`，所以 `fn` 在第 `t` 步接受的前几个参数是：\n",
    "```\n",
    "Sequence1[t-3]\n",
    "Sequence1[t+2]\n",
    "Sequence1[t-1]\n",
    "Sequence2[t]\n",
    "Sequence3[t+3]\n",
    "```\n",
    "\n",
    "然后 `Output1` 传入的是 `[-3, -5]`（**传入的初始值的形状应为 `shape (5,)+`**），`Output2` 不作为参数传入，`Output3` 传入的是 `[-1]`，所以接下的参数是：\n",
    "```\n",
    "Output1[t-3]\n",
    "Output1[t-5]\n",
    "Output3[t-1]\n",
    "Argument1\n",
    "Argument2\n",
    "```\n",
    "\n",
    "总的说来上面的例子中，`fn` 函数按照以下顺序最多接受这样 10 个参数：\n",
    "```\n",
    "Sequence1[t-3]\n",
    "Sequence1[t+2]\n",
    "Sequence1[t-1]\n",
    "Sequence2[t]\n",
    "Sequence3[t+3]\n",
    "Output1[t-3]\n",
    "Output1[t-5]\n",
    "Output3[t-1]\n",
    "Argument1\n",
    "Argument2\n",
    "```\n",
    "\n",
    "例子，假设 $x$ 是我们的输入，$y$ 是我们的输出，我们需要计算 $y(t) = tanh\\left[W_{1} y(t-1) + W_{2} x(t) + W_{3} x(t-1)\\right]$ 的值："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "X = T.matrix(\"X\")\n",
    "Y = T.vector(\"y\")\n",
    "\n",
    "W_1 = T.matrix(\"W_1\")\n",
    "W_2 = T.matrix(\"W_2\")\n",
    "W_3 = T.matrix(\"W_3\")\n",
    "\n",
    "# W_yy 和 W_xy 作为不变的参数可以直接使用\n",
    "results, _ = theano.scan(fn = lambda x, x_pre, y: T.tanh(T.dot(W_1, y) + T.dot(W_2, x) + T.dot(W_3, x_pre)), \n",
    "                         # 0 对应 x，-1 对应 x_pre\n",
    "                         sequences = dict(input=X, taps=[0, -1]), \n",
    "                         outputs_info = Y)\n",
    "\n",
    "Y_seq = theano.function(inputs = [X, Y, W_1, W_2, W_3], \n",
    "                        outputs = results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试小矩阵计算："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "theano running time 0.0537 s\n",
      "numpy  running time 0.0197 s\n",
      "the max difference of the first 10 results is 1.25780650354e-06\n"
     ]
    }
   ],
   "source": [
    "# 测试\n",
    "t = 1001\n",
    "x_dim = 10\n",
    "y_dim = 20\n",
    "\n",
    "x = 2 * floatX(np.random.random([t, x_dim])) - 1\n",
    "y = 2 * floatX(np.zeros(y_dim)) - 1\n",
    "w_1 = 2 * floatX(np.random.random([y_dim, y_dim])) - 1\n",
    "w_2 = 2 * floatX(np.random.random([y_dim, x_dim])) - 1\n",
    "w_3 = 2 * floatX(np.random.random([y_dim, x_dim])) - 1\n",
    "\n",
    "tic = time.time()\n",
    "\n",
    "y_res_theano = Y_seq(x, y, w_1, w_2, w_3)\n",
    "\n",
    "print \"theano running time {:.4f} s\".format(time.time() - tic)\n",
    "\n",
    "tic = time.time()\n",
    "# 与 numpy 的结果进行比较：\n",
    "y_res_numpy = np.zeros([t, y_dim])\n",
    "y_res_numpy[0] = y\n",
    "\n",
    "for i in range(1, t):\n",
    "    y_res_numpy[i] = np.tanh(w_1.dot(y_res_numpy[i-1]) + w_2.dot(x[i]) + w_3.dot(x[i-1]))\n",
    "\n",
    "print \"numpy  running time {:.4f} s\".format(time.time() - tic)\n",
    "\n",
    "# 这里要从 1 开始，因为使用了 x(t-1)，所以 scan 从第 1 个位置开始计算\n",
    "print \"the max difference of the first 10 results is\", np.max(np.abs(y_res_theano[0:10] - y_res_numpy[1:11]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "测试大矩阵运算："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "theano running time 0.0754 s\n",
      "numpy  running time 0.1334 s\n",
      "the max difference of the first 10 results is 0.000656997077348\n"
     ]
    }
   ],
   "source": [
    "# 测试\n",
    "t = 1001\n",
    "x_dim = 100\n",
    "y_dim = 200\n",
    "\n",
    "x = 2 * floatX(np.random.random([t, x_dim])) - 1\n",
    "y = 2 * floatX(np.zeros(y_dim)) - 1\n",
    "w_1 = 2 * floatX(np.random.random([y_dim, y_dim])) - 1\n",
    "w_2 = 2 * floatX(np.random.random([y_dim, x_dim])) - 1\n",
    "w_3 = 2 * floatX(np.random.random([y_dim, x_dim])) - 1\n",
    "\n",
    "tic = time.time()\n",
    "\n",
    "y_res_theano = Y_seq(x, y, w_1, w_2, w_3)\n",
    "\n",
    "print \"theano running time {:.4f} s\".format(time.time() - tic)\n",
    "\n",
    "tic = time.time()\n",
    "# 与 numpy 的结果进行比较：\n",
    "y_res_numpy = np.zeros([t, y_dim])\n",
    "y_res_numpy[0] = y\n",
    "\n",
    "for i in range(1, t):\n",
    "    y_res_numpy[i] = np.tanh(w_1.dot(y_res_numpy[i-1]) + w_2.dot(x[i]) + w_3.dot(x[i-1]))\n",
    "\n",
    "print \"numpy  running time {:.4f} s\".format(time.time() - tic)\n",
    "\n",
    "# 这里要从 1 开始，因为使用了 x(t-1)，所以 scan 从第 1 个位置开始计算\n",
    "print \"the max difference of the first 10 results is\", np.max(np.abs(y_res_theano[:10] - y_res_numpy[1:11]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "值得注意的是，由于 `theano` 和 `numpy` 在某些计算的实现上存在一定的差异，随着序列长度的增加，这些差异将被放大："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "iter 001, max diff:0.000002\n",
      "iter 002, max diff:0.000005\n",
      "iter 003, max diff:0.000007\n",
      "iter 004, max diff:0.000010\n",
      "iter 005, max diff:0.000024\n",
      "iter 006, max diff:0.000049\n",
      "iter 007, max diff:0.000113\n",
      "iter 008, max diff:0.000145\n",
      "iter 009, max diff:0.000334\n",
      "iter 010, max diff:0.000657\n",
      "iter 011, max diff:0.001195\n",
      "iter 012, max diff:0.002778\n",
      "iter 013, max diff:0.004561\n",
      "iter 014, max diff:0.004748\n",
      "iter 015, max diff:0.014849\n",
      "iter 016, max diff:0.012696\n",
      "iter 017, max diff:0.043639\n",
      "iter 018, max diff:0.046540\n",
      "iter 019, max diff:0.083032\n",
      "iter 020, max diff:0.123678\n"
     ]
    }
   ],
   "source": [
    "for i in xrange(20):\n",
    "    print \"iter {:03d}, max diff:{:.6f}\".format(i + 1, \n",
    "                                                np.max(np.abs(y_res_numpy[i + 1,:] - y_res_theano[i,:])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 控制循环次数"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "假设我们要计算方阵$A$的$A^k$，$k$ 是一个未知变量，我们可以这样通过 `n_steps` 参数来控制循环计算的次数： "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[ 107616. -152192.]\n",
      " [ -76096.  107616.]]\n",
      "[[ 107616. -152192.]\n",
      " [ -76096.  107616.]]\n"
     ]
    }
   ],
   "source": [
    "A = T.matrix(\"A\")\n",
    "k = T.iscalar(\"k\")\n",
    "\n",
    "results, _ = theano.scan(fn = lambda P, A: P.dot(A),\n",
    "                         # 初始值设为单位矩阵\n",
    "                         outputs_info = T.eye(A.shape[0]),\n",
    "                         # 乘 k 次\n",
    "                         non_sequences = A,\n",
    "                         n_steps = k)\n",
    "\n",
    "A_k = theano.function(inputs = [A, k], outputs = results[-1])\n",
    "\n",
    "test_a = floatX([[2, -2], [-1, 2]])\n",
    "\n",
    "print A_k(test_a, 10)\n",
    "\n",
    "# 使用 numpy 进行验证\n",
    "a_k = np.eye(2)\n",
    "for i in range(10):\n",
    "    a_k = a_k.dot(test_a)\n",
    "    \n",
    "print a_k"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用共享变量"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "可以在 `scan` 中使用并更新共享变量，例如，利用共享变量 `n`，我们可以实现这样一个迭代 `k` 步的简单计数器："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0\n",
      "10.0\n",
      "20.0\n"
     ]
    }
   ],
   "source": [
    "n = theano.shared(floatX(0))\n",
    "k = T.iscalar(\"k\")\n",
    "\n",
    "# 这里 lambda 的返回值是一个 dict，因此这个值会被传入 updates 中\n",
    "_, updates = theano.scan(fn = lambda n: {n:n+1},\n",
    "                         non_sequences = n,\n",
    "                         n_steps = k)\n",
    "\n",
    "counter = theano.function(inputs = [k], \n",
    "                          outputs = [],\n",
    "                          updates = updates)\n",
    "\n",
    "print n.get_value()\n",
    "counter(10)\n",
    "print n.get_value()\n",
    "counter(10)\n",
    "print n.get_value()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "之前说到，`fn` 函数的返回值应该是 `(outputs_list, update_dictionary)` 或者 `(update_dictionary, outputs_list)` 或者两者之一。\n",
    "\n",
    "这里 `fn` 函数返回的是一个字典，因此自动被放入了 `update_dictionary` 中，然后传入 `function` 的 `updates` 参数中进行迭代。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 使用条件语句结束循环"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以将 `scan` 设计为 `loop-until` 的模式，具体方法是在 `scan` 中，将 `fn` 的返回值增加一个参数，使用 `theano.scan_module` 来设置停止条件。\n",
    "\n",
    "假设我们要计算所有不小于某个值的 2 的幂，我们可以这样定义："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[  2.   4.   8.  16.  32.]\n"
     ]
    }
   ],
   "source": [
    "max_value = T.scalar()\n",
    "\n",
    "results, _ = theano.scan(fn = lambda v_pre, max_v: (v_pre * 2, theano.scan_module.until(v_pre * 2 > max_v)), \n",
    "                         outputs_info = T.constant(1.),\n",
    "                         non_sequences = max_value,\n",
    "                         n_steps = 1000)\n",
    "\n",
    "# 注意，这里不能取 results 的全部\n",
    "# 例如在输入值为 40 时，最后的输出可以看成 (64, False)\n",
    "# scan 发现停止条件满足，停止循环，但是不影响 64 被输出到 results 中，因此要将 64 去掉\n",
    "power_of_2 = theano.function(inputs = [max_value], outputs = results[:-1])\n",
    "\n",
    "print power_of_2(40)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 2",
   "language": "python",
   "name": "python2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
